Variational autoencoders were trained separately for each bird. To determine the optimal number of embedding dimensions, I calculated the Calinski-Harabasz index (the ratio of the between-cluster variance to the within-cluster variance) using the pre-labelled clusters (Fig. \@ref(fig:vae-dimensions)). Bird 7358 (66-68 DPH) has relatively stable syllables and song syntax, while bird 6951 (59-63 DPH) has more variable syllables and syntax (Fig. \@ref(fig:syntax-graphs)). For bird 7358, little information is gained beyond 32 dimensions.
library(plotly)
# Read the training metrics table and keep only the rows that carry a bird ID.
trainingLoss <- local({
  raw_metrics <- read.table(
    "Data/trainingLoss.csv",
    header = TRUE,
    sep = ","
  )
  raw_metrics[!is.na(raw_metrics$bird), ]
})
# Dual-axis line plot of reconstruction loss ("RL") and Calinski-Harabasz
# index against the number of embedded dimensions.
#
# Arguments:
#   tbl            data frame; its `dimensions`, `RL` and `CalinskiHarabasz`
#                  columns are plotted.
#   title          text placed as a centred annotation above the plot area.
#   overlayingAxis value given to the `overlaying` attribute of the
#                  reconstruction-loss axis (`yaxis2`).
#
# Returns the plotly object.
create_plot <- function(tbl, title, overlayingAxis) {
  loss_colour <- "#0484bf"
  index_colour <- "#ff7f0e"

  # Title drawn in paper coordinates so it sits above the plotting area.
  title_annotation <- list(
    x = 0.5, y = 1.05, text = title, font = list(size = 18),
    showarrow = FALSE, xref = "paper", yref = "paper", xanchor = "center"
  )

  dimension_axis <- list(
    title = "Embedded dimensions", zeroline = FALSE, showline = TRUE,
    linewidth = 2, linecolor = "black", showgrid = FALSE,
    tickvals = list(8, 16, 32, 64, 128), ticks = "inside"
  )

  # Right-hand axis, coloured like the Calinski-Harabasz trace.
  index_axis <- list(
    showgrid = FALSE, showline = TRUE, zeroline = FALSE,
    tickfont = list(color = index_colour, size = 11), color = index_colour,
    title = "Calinski-Harabasz Index",
    linewidth = 2,
    overlaying = "n", side = "right"
  )

  # Left-hand axis, coloured like the reconstruction-loss trace.
  loss_axis <- list(
    showgrid = FALSE, showline = TRUE, zeroline = FALSE,
    tickfont = list(color = loss_colour, size = 11), color = loss_colour,
    title = "Reconstruction Loss",
    linewidth = 2,
    overlaying = overlayingAxis, side = "left"
  )

  p <- plot_ly()
  # Reconstruction loss is drawn against the secondary axis (y2).
  p <- add_trace(
    p,
    data = tbl, x = ~dimensions, y = ~RL, name = "RL",
    yaxis = "y2", type = "scatter", mode = "lines+markers"
  )
  # The clustering index is drawn against the primary axis (y1).
  p <- add_trace(
    p,
    data = tbl, x = ~dimensions, y = ~CalinskiHarabasz,
    name = "Calinski-Harabasz Index",
    yaxis = "y1", type = "scatter", mode = "lines+markers"
  )

  layout(
    p,
    showlegend = FALSE,
    annotations = title_annotation,
    colorway = c(loss_colour, index_colour),
    hovermode = "x unified",
    xaxis = dimension_axis,
    yaxis = index_axis,
    yaxis2 = loss_axis
  )
}
# One panel per bird. The second panel passes "y3" as the axis that its
# reconstruction-loss axis overlays.
fig1 <- create_plot(
  trainingLoss[trainingLoss$bird == 7358, ],
  title = "Bird 7358",
  overlayingAxis = "y"
)
fig2 <- create_plot(
  trainingLoss[trainingLoss$bird == 6951, ],
  title = "Bird 6951",
  overlayingAxis = "y3"
)

# Place the two panels side by side, keeping both axis titles.
fig <- subplot(
  fig1, fig2,
  titleY = TRUE,
  titleX = TRUE,
  margin = 0.1
)
fig
Reconstruction loss and Calinski-Harabasz Index.
# fig1
# fig1
# DaviesBouldin
# CalinskiHarabasz
# silhouettesilhouette
# Reload the training metrics and drop the rows without a bird ID.
trainingLoss <- read.table(
  "Data/trainingLoss.csv",
  header = TRUE,
  sep = ","
)
trainingLoss <- trainingLoss[!is.na(trainingLoss$bird), ]

# Store the bird ID as a string; it is used as the colour grouping below.
trainingLoss$bird <- as.character(trainingLoss$bird)
# Line plot of F1 score (log-scaled y axis) against the number of embedded
# dimensions, with one coloured line per bird.
#
# Arguments:
#   tbl            data frame; its `dimensions`, `f1` and `bird` columns are
#                  plotted.
#   title          text placed as a centred annotation above the plot area.
#   overlayingAxis accepted as part of the signature but not used in the body.
#
# Returns the plotly object.
create_plot <- function(tbl, title, overlayingAxis) {
  # Title drawn in paper coordinates so it sits above the plotting area.
  title_annotation <- list(
    x = 0.5, y = 1.05, text = title, font = list(size = 18),
    showarrow = FALSE, xref = "paper", yref = "paper", xanchor = "center"
  )

  dimension_axis <- list(
    title = "Embedded dimensions", zeroline = FALSE, showline = TRUE,
    linewidth = 2, linecolor = "black", showgrid = FALSE,
    tickvals = list(8, 16, 32, 64, 128), ticks = "inside"
  )

  score_axis <- list(
    showgrid = FALSE, showline = TRUE, zeroline = FALSE,
    tickfont = list(size = 11),
    title = "F1 Score",
    type = "log",
    linewidth = 2,
    overlaying = "n", side = "left"
  )

  p <- plot_ly()
  p <- add_trace(
    p,
    data = tbl, x = ~dimensions, y = ~f1,
    color = ~bird, colors = c("red", "blue"),
    type = "scatter", mode = "lines+markers"
  )

  layout(
    p,
    showlegend = TRUE,
    legend = list(title = list(text = "Bird")),
    annotations = title_annotation,
    hovermode = "x unified",
    xaxis = dimension_axis,
    yaxis = score_axis
  )
}
# Single panel containing both birds; the result is printed so it renders.
fig <- create_plot(
  trainingLoss,
  title = "F1 Score",
  overlayingAxis = "y"
)
fig
F1 score for birds 7358 and 6951.
# Embed the pre-rendered reconstruction image.
knitr::include_graphics(
  "Figures/VAE/reconstructions-7358-32.png"
)
Input (left) and decoded (right) syllables.
# Embed the pre-rendered traversal image.
knitr::include_graphics(
  "Figures/VAE/traversal - 7358.png"
)
Traversal of the embedding space from the centroid of syllable “i” to the centroid of each other syllable.
# Embed the pre-rendered UMAP images, one per bird.
knitr::include_graphics(
  "Figures/VAE/umap - 32 - bird 7358.png"
)
knitr::include_graphics(
  "Figures/VAE/umap - 32 - bird 6951.png"
)
Syllable clusters from embedded dimensions.
library(htmlwidgets)
library(slickR)
library(stringr)
# cP1 <- htmlwidgets::JS("function(slick,index) {
# return '<a>'+(index+1)+'</a>';
# }")
# Settings added to every slickR carousel built below.
# (The custom paging callback is currently disabled: customPaging = cP1.)
opts <- settings(dots = TRUE, speed = 0, adaptiveHeight = FALSE)
# Build one image carousel per recording session of bird 7358 and stack them.
#
# Session and unit numbers are parsed from each file's base name, so the parse
# does not depend on how list.files() joins the directory and the file name,
# and the `filename` column always holds the real path that was listed.
# Files whose name does not follow the complete
# "umap - spikes - Bird 7358 - Session <n> - Unit <n>.png" pattern are dropped.
# Sessions and units are ordered numerically (2 before 10) rather than
# alphabetically.
filenames <- list.files(
  path = "Figures/VAE/",
  pattern = "umap - spikes - Bird 7358 - Session",
  full.names = TRUE
)

# Columns of `parsed`: full match, session number, unit number.
parsed <- str_match(
  basename(filenames),
  "^umap - spikes - Bird 7358 - Session (\\d+) - Unit (\\d+)\\.png$"
)

filenames <- data.frame(
  filename = filenames,
  session = as.numeric(parsed[, 2]),
  neuron = as.numeric(parsed[, 3]),
  stringsAsFactors = FALSE
)

# Drop files that did not match, then sort by session and unit number.
filenames <- filenames[!is.na(filenames$session), ]
filenames <- filenames[order(filenames$session, filenames$neuron), ]

# Numeric session keys give numerically ordered groups.
sets <- split(filenames$filename, filenames$session)

# One carousel per session, each with the shared settings applied.
slicks <- lapply(
  sets,
  FUN = function(x, ...) {
    slickR(obj = x, ...) + opts
  },
  height = "90%", width = "90%", padding = 0
)

# Stack the per-session carousels vertically (NULL when no files were found).
Reduce(`%stack%`, slicks)